{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "from datasets import load_dataset\n",
    "\n",
    "from transformers import PreTrainedTokenizerFast, PhiForCausalLM, TrainingArguments\n",
    "from datasets import load_dataset\n",
    "import pandas as pd\n",
    "import time\n",
    "\n",
    "from trl import SFTTrainer, DataCollatorForCompletionOnlyLM"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 这个版本SFT使用的是SFTTrainer，map数据集比较慢，仅供参考"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sft_file = './data/sft_train_data.parquet'\n",
    "tokenizer_dir = './model_save/tokenizer/'\n",
    "sft_from_checkpoint_file = './model_save/pre/'\n",
    "model_save_dir = './model_save/sft/'\n",
    "max_seq_len = 320"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = load_dataset(path='parquet', data_files=sft_file, split='train', cache_dir='.cache')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "samples = dataset[0:2]\n",
    "print(samples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = PreTrainedTokenizerFast.from_pretrained(tokenizer_dir)\n",
    "len(tokenizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def formatting_prompts_func(example: list[dict]) -> list[str]:\n",
    "    batch_txt = []\n",
    "    for i in range(len(example['instruction'])):\n",
    "        text = f\"[BOS]##提问:\\n{example['instruction'][i]}\\n##回答:\\n{example['output'][i]}[EOS]\"\n",
    "        batch_txt.append(text)\n",
    "        \n",
    "    return batch_txt\n",
    "\n",
    "formatting_prompts_func(samples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "instruction_template = \"##提问:\"\n",
    "response_template = \"##回答:\"\n",
    "collator = DataCollatorForCompletionOnlyLM(instruction_template=instruction_template, response_template=response_template, tokenizer=tokenizer, mlm=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(tokenizer([instruction_template, response_template])['input_ids'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "model = PhiForCausalLM.from_pretrained(sft_from_checkpoint_file)\n",
    "\n",
    "model_size = sum(t.numel() for t in model.parameters())\n",
    "print(f\"Phi2 size: {model_size / 1000**2:.2f}M parameters\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "args = TrainingArguments(\n",
    "    output_dir=model_save_dir,\n",
    "    per_device_train_batch_size=8,\n",
    "    gradient_accumulation_steps=8,\n",
    "    num_train_epochs=3,\n",
    "    weight_decay=0.1,\n",
    "    warmup_steps=1000,\n",
    "    learning_rate=5e-5,\n",
    "    save_steps=2000,\n",
    "    save_total_limit=3,\n",
    "    report_to='tensorboard',\n",
    "    optim=\"adafactor\",\n",
    "    bf16=True,\n",
    "    logging_steps=10,\n",
    "    log_level='info',\n",
    "    logging_first_step=True,\n",
    ")\n",
    "trainer = SFTTrainer(\n",
    "    model,\n",
    "    args=args,\n",
    "    train_dataset=dataset,\n",
    "    formatting_func=formatting_prompts_func,\n",
    "    data_collator=collator,\n",
    "    tokenizer=tokenizer,\n",
    "    max_seq_length=384, #  eos\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.train(\n",
    "    # resume_from_checkpoint=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "loss_log = pd.DataFrame(trainer.state.log_history)\n",
    "loss_log.to_csv(f\"./logs/sft_train_log_{time.strftime('%Y%m%d-%H%M')}.csv\")\n",
    "\n",
    "\n",
    "trainer.save_model(model_save_dir)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py310",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
